from typing import Any, ClassVar, Dict, Optional, Type

import marshmallow
from marshmallow_oneofschema import OneOfSchema

from starkware.starkware_utils.validated_dataclass import ValidatedMarshmallowDataclass


class SchemaTracker:
    """
    Tracks a set of classes and provides a OneOfSchema which can be used to serialize them.

    Classes are registered by name with add_class. The schema class exposed as `self.Schema`
    resolves each registered class's own `Schema` attribute lazily, the first time an object
    (on dump) or a payload (on load) of that class is seen, and caches it in its
    `type_schemas` mapping.
    """

    def __init__(self):
        # Maps a class's __name__ to the class itself; filled by add_class.
        self.classes: Dict[str, Any] = {}
        # Local alias captured by the methods of the inner schema class below.
        classes = self.classes

        class SchemaTrackerInner(OneOfSchema):
            """
            A OneOfSchema implementation that dynamically adds Schema classes of allowed
            (registered in advance) classes to its schema mapping.
            """

            # This class is created inside SchemaTracker.__init__, so each tracker gets its own
            # mapping; separate trackers do not share cached schemas.
            type_schemas: Dict[str, Type[marshmallow.Schema]] = {}

            def get_data_type(self, data):
                data_type = super().get_data_type(data)
                # Only names registered via add_class are added to the mapping. Any other value
                # (None when the type field is missing, an unhashable value, or an unknown name)
                # is returned unchanged so that OneOfSchema's own validation reports it, rather
                # than this method failing with a KeyError/TypeError.
                if (
                    isinstance(data_type, str)
                    and data_type in classes
                    and data_type not in self.type_schemas
                ):
                    self.type_schemas[data_type] = classes[data_type].Schema
                return data_type

            def get_obj_type(self, obj):
                name = super().get_obj_type(obj)
                obj_type = type(obj)
                # Require the exact class registered under this name, not merely some class
                # that happens to share the name.
                assert (
                    classes.get(name) is obj_type
                ), f"Trying to serialize the object {obj} that was not registered first."

                # The Schema is looked up here rather than in add_class, since it might not
                # exist yet when the class itself is registered.
                if name not in self.type_schemas:
                    self.type_schemas[name] = obj_type.Schema

                return name

        # The OneOfSchema class to use for (de)serializing instances of the registered classes.
        self.Schema = SchemaTrackerInner

    def add_class(self, cls: type):
        """
        Registers cls under its __name__.

        Registering the same class again is a no-op; registering a different class under a name
        that is already taken fails an assertion.
        """
        cls_name = cls.__name__
        if cls_name in self.classes:
            assert (
                self.classes[cls_name] is cls
            ), f"Trying to register two classes with the same name {cls_name}"
        else:
            self.classes[cls_name] = cls


class SubclassSchemaTracker(ValidatedMarshmallowDataclass):
    """
    Base class whose subclasses are automatically registered as options of a OneOfSchema
    used as their (de)serializer.

    This avoids hand-writing a OneOfSchema class that has to know every derived class and
    its schema.

    Usage example:

        class Base(SubclassSchemaTracker):
            pass

        Base.track_subclasses()

        @marshmallow_dataclass.dataclass(frozen=True)
        class Derived(Base):
            pass

    Derived is registered with the tracker created for Base at the moment the Derived class
    is created (i.e. at inheritance time).
    """

    # Tracker that newly created subclasses register with. None means no ancestor has called
    # track_subclasses; subclasses see their ancestor's value through normal attribute lookup.
    subclass_schema_tracker: ClassVar[Optional[SchemaTracker]] = None

    @classmethod
    def track_subclasses(cls):
        """
        Starts tracking subclasses of cls.

        Installs a fresh SchemaTracker on cls and replaces cls.Schema with that tracker's
        OneOfSchema. Only subclasses created after this call are registered; cls itself is not.
        """
        tracker = SchemaTracker()
        cls.subclass_schema_tracker = tracker
        cls.Schema = tracker.Schema

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """
        Runs when a subclass is created: after the parent hooks, registers cls with the tracker
        found on it by attribute lookup (installed by the nearest ancestor that called
        track_subclasses), if any.
        """
        super().__init_subclass__(**kwargs)  # type: ignore[call-arg]

        tracker = cls.subclass_schema_tracker
        if tracker is not None:
            tracker.add_class(cls)
